{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Matrix-factorization-based Cognitive Diagnosis model (MCD)\n",
    "\n",
    "This notebook shows how to train and use MCD.\n",
    "First, we load the data (here we use the a0910 dataset).\n",
    "Then we train an MCD model and save its parameters to a file.\n",
    "Finally, we load the parameters from the file and evaluate the model on the test set.\n",
    "\n",
    "The script version can be found in [MCD.py](MCD.py)."
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Data Preparation\n",
    "\n",
    "Before processing the data, we first need to acquire the dataset, as shown in [prepare_dataset.ipynb](prepare_dataset.ipynb)."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [
    {
     "data": {
      "text/plain": "   user_id  item_id  score\n0     1615    12977      1\n1      782    13124      0\n2     1084    16475      0\n3      593     8690      0\n4      127    14225      1",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>user_id</th>\n      <th>item_id</th>\n      <th>score</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>1615</td>\n      <td>12977</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>1</th>\n      <td>782</td>\n      <td>13124</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>2</th>\n      <td>1084</td>\n      <td>16475</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>3</th>\n      <td>593</td>\n      <td>8690</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>4</th>\n      <td>127</td>\n      <td>14225</td>\n      <td>1</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the data from files\n",
    "import pandas as pd\n",
    "\n",
    "train_data = pd.read_csv(\"../../data/a0910/train.csv\")\n",
    "valid_data = pd.read_csv(\"../../data/a0910/valid.csv\")\n",
    "test_data = pd.read_csv(\"../../data/a0910/test.csv\")\n",
    "\n",
    "train_data.head(5)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [
    {
     "data": {
      "text/plain": "(186049, 25606, 55760)"
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_data), len(valid_data), len(test_data)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [
    {
     "data": {
      "text/plain": "(<torch.utils.data.dataloader.DataLoader at 0x20965372340>,\n <torch.utils.data.dataloader.DataLoader at 0x20965372be0>,\n <torch.utils.data.dataloader.DataLoader at 0x209652a1e50>)"
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Transform data to torch DataLoader (i.e., batchify)\n",
    "# batch_size is set to 256\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "batch_size = 256\n",
    "def transform(x, y, z, batch_size, **params):\n",
    "    # Wrap user ids (x), item ids (y), and scores (z) in a batched DataLoader\n",
    "    dataset = TensorDataset(\n",
    "        torch.tensor(x, dtype=torch.int64),\n",
    "        torch.tensor(y, dtype=torch.int64),\n",
    "        torch.tensor(z, dtype=torch.float)\n",
    "    )\n",
    "    return DataLoader(dataset, batch_size=batch_size, **params)\n",
    "\n",
    "train, valid, test = [\n",
    "    transform(data[\"user_id\"], data[\"item_id\"], data[\"score\"], batch_size)\n",
    "    for data in [train_data, valid_data, test_data]\n",
    "]\n",
    "train, valid, test"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Training and Persistence"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [],
   "source": [
    "import logging\n",
    "logging.getLogger().setLevel(logging.INFO)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|██████████| 727/727 [00:17<00:00, 42.18it/s]\n",
      "evaluating: 100%|██████████| 101/101 [00:00<00:00, 323.55it/s]\n",
      "Epoch 1: 100%|██████████| 727/727 [00:15<00:00, 47.61it/s]\n",
      "evaluating: 100%|██████████| 101/101 [00:00<00:00, 338.68it/s]\n",
      "INFO:root:save parameters to mcd.params\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 0] LogisticLoss: 0.635566\n",
      "[Epoch 0] auc: 0.669028, accuracy: 0.688901\n",
      "[Epoch 1] LogisticLoss: 0.570500\n",
      "[Epoch 1] auc: 0.721980, accuracy: 0.709482\n"
     ]
    }
   ],
   "source": [
    "from EduCDM import MCD\n",
    "\n",
    "# Positional arguments here: number of users, number of items, latent dimension\n",
    "cdm = MCD(4164, 17747, 100)\n",
    "\n",
    "cdm.train(train, valid, epoch=2)\n",
    "cdm.save(\"mcd.params\")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Loading and Testing"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:load parameters from mcd.params\n",
      "evaluating: 100%|██████████| 218/218 [00:00<00:00, 318.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "auc: 0.722702, accuracy: 0.707694\n"
     ]
    }
   ],
   "source": [
    "cdm.load(\"mcd.params\")\n",
    "auc, accuracy = cdm.eval(test)\n",
    "print(\"auc: %.6f, accuracy: %.6f\" % (auc, accuracy))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}